{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9ffd025a",
   "metadata": {},
   "source": [
    "## 1. 数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "015f8d5b",
   "metadata": {},
   "source": [
    "### 1.1 代码包引入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "200bc3bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import string\n",
    "from collections import Counter  # 计数类\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as Data\n",
    "from torch import optim\n",
    "from tqdm import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "271a3a91",
   "metadata": {},
   "source": [
    "### 1.2 数据整理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "24ea176c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['是时候教教你针线活了\\t手把手教\\t别，不用。\\n',\n",
       " '谢谢你所做的一切\\t你开心就好\\t开心\\t嗯因为你的心里只有学习\\t某某某，还有你\\t这个某某某用的好\\t\\n',\n",
       " '今天好点了吗？\\t一天比一天严重\\t吃药不管用，去打一针。别拖着\\t\\n',\n",
       " '加油，三月动起来，五月笑起来\\t正解你为什么就那么厉害呢\\t哈哈，没办法，智商就是这么高\\t你这是要开始得瑟了吗！好啦！你最厉害！\\t哈哈哈哈\\t\\n',\n",
       " '因为我网络差吗，加载不出来啊\\t是什么网络啊，移动的可能比较慢\\t啊真的是我自家网络的问题，用别家的就好了\\t\\n',\n",
       " '这个側颜可以\\t这个繁体很厉害\\t为了配你的颜\\t哥，\\t\\n',\n",
       " '对啊。你以为你是谁。留言就要回复吗\\t对啊，什么逻辑\\t大概是\\n',\n",
       " '保留几分\\t晾毛巾晾毛巾叫你天天晾毛巾\\t看看几点了还不快去吃饭\\t\\t\\n',\n",
       " '哈哈哈哈哈哈哈哈哈哈哈哈\\t你哈哈个啥\\t别人晒得都是某某某搞事情为啥只有你是瘦了！太尼玛懂你了笑哭我了\\t这个太假了\\t\\n',\n",
       " '看着就难吃\\t又不是你吃，莫操心\\t心疼你\\t滚蛋\\t不滚\\n']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the raw dialogues: one dialogue per line, utterances separated by tabs\n",
    "with open('data/dataset.txt', mode='r', encoding='utf-8') as dataset_file:\n",
    "    datas = list(dataset_file)\n",
    "datas[:10]  # peek at the first ten raw lines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "714abcf0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'，', ' ', '\\n', \"'\", '？', '》', '《', '.', '。', '！', ',', '\"', ';', '!', '、', '?', '’', '\\t', '‘', '；'}\n"
     ]
    }
   ],
   "source": [
    "# Inspect which non-CJK, non-alphanumeric characters occur in the corpus\n",
    "content = ''.join(datas)\n",
    "# Blank out every CJK ideograph so only the remaining characters are left\n",
    "special_char = re.sub(r'[\\u4e00-\\u9fa5]', ' ', content)\n",
    "alphanumerics = set(string.ascii_letters) | set(string.digits)\n",
    "\n",
    "print(set(special_char) - alphanumerics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "80c18643",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tokens: [['是', '时', '候', '教', '教', '你', '针', '线', '活', '了', '<sep>', '手', '把', '手', '教', '<sep>', '别', '，', '不', '用', '。', '<sep>'], ['谢', '谢', '你', '所', '做', '的', '一', '切', '<sep>', '你', '开', '心', '就', '好', '<sep>', '开', '心', '<sep>', '嗯', '因', '为', '你', '的', '心', '里', '只', '有', '学', '习', '<sep>', '某', '某', '某', '，', '还', '有', '你', '<sep>', '这', '个', '某', '某', '某', '用', '的', '好', '<sep>'], ['今', '天', '好', '点', '了', '吗', '？', '<sep>', '一', '天', '比', '一', '天', '严', '重', '<sep>', '吃', '药', '不', '管', '用', '，', '去', '打', '一', '针', '。', '别', '拖', '着', '<sep>'], ['加', '油', '，', '三', '月', '动', '起', '来', '，', '五', '月', '笑', '起', '来', '<sep>', '正', '解', '你', '为', '什', '么', '就', '那', '么', '厉', '害', '呢', '<sep>', '哈', '哈', '，', '没', '办', '法', '，', '智', '商', '就', '是', '这', '么', '高', '<sep>', '你', '这', '是', '要', '开', '始', '得', '瑟', '了', '吗', '！', '好', '啦', '！', '你', '最', '厉', '害', '！', '<sep>', '哈', '哈', '哈', '哈', '<sep>'], ['因', '为', '我', '网', '络', '差', '吗', '，', '加', '载', '不', '出', '来', '啊', '<sep>', '是', '什', '么', '网', '络', '啊', '，', '移', '动', '的', '可', '能', '比', '较', '慢', '<sep>', '啊', '真', '的', '是', '我', '自', '家', '网', '络', '的', '问', '题', '，', '用', '别', '家', '的', '就', '好', '了', '<sep>'], ['这', '个', '側', '颜', '可', '以', '<sep>', '这', '个', '繁', '体', '很', '厉', '害', '<sep>', '为', '了', '配', '你', '的', '颜', '<sep>', '哥', '，', '<sep>']]\n"
     ]
    }
   ],
   "source": [
    "# Tokenization: one token per character, tabs (utterance boundaries) become <sep>\n",
    "def tokenize(datas):\n",
    "    \"\"\"Turn each raw dialogue line into a list of character tokens.\n",
    "\n",
    "    Leading/trailing whitespace (including the newline and any trailing tabs) is\n",
    "    stripped first, every remaining tab is replaced by ``<sep>``, and one extra\n",
    "    ``<sep>`` is appended to close the last utterance.\n",
    "    \"\"\"\n",
    "    def _line_tokens(line):\n",
    "        cleaned = line.strip().replace(\"\\n\", \"\")\n",
    "        return [\"<sep>\" if ch == '\\t' else ch for ch in cleaned] + ['<sep>']\n",
    "\n",
    "    return [_line_tokens(line) for line in datas]\n",
    "\n",
    "tokens = tokenize(datas)\n",
    "preview = tokens[:6]  # first few tokenized dialogues, as a sanity check\n",
    "print(\"tokens:\", preview)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dcea6c3",
   "metadata": {},
   "source": [
    "### 1.3 构建词表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71488d94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten(nested):\n",
    "    \"\"\"Flatten a list of lists by one level.\"\"\"\n",
    "    return [item for sublist in nested for item in sublist]\n",
    "\n",
    "\n",
    "# Vocabulary: bidirectional token <-> id mapping\n",
    "class Vocab:\n",
    "    \"\"\"Token <-> id mapping built from a 2-D list of tokens (one list per dialogue).\n",
    "\n",
    "    Ids 0-2 are reserved for ``<pad>``, ``<unk>`` and ``<seq>``; corpus tokens get ids\n",
    "    from 3 upwards in descending frequency order (ties keep first-seen order).\n",
    "\n",
    "    NOTE(review): the reserved ``'<seq>'`` looks like a typo for ``'<sep>'``, which the\n",
    "    tokenizer emits and which therefore receives an ordinary frequency-ranked id. It is\n",
    "    left unchanged because renaming it would shift every id and invalidate checkpoints\n",
    "    trained with the current mapping.\n",
    "\n",
    "    Indexing:\n",
    "        vocab['x']    -> id of 'x', or the ``<unk>`` id if 'x' was never seen\n",
    "        vocab[5]      -> token with id 5, or ``'<unk>'`` for an unknown id\n",
    "        vocab[[...]]  -> element-wise lookup of a list or tuple\n",
    "    Any other index type raises TypeError.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, tokens):\n",
    "        self.tokens = tokens  # 2-D list of tokens\n",
    "        self.token2index = {'<pad>': 0, '<unk>': 1, '<seq>': 2}  # reserved special tokens\n",
    "        # most_common() with no argument is a stable sort by count, descending -- the same\n",
    "        # order as sorted(counter.items(), key=count, reverse=True)\n",
    "        counts = Counter(flatten(self.tokens))\n",
    "        self.token2index.update({\n",
    "            token: index + 3\n",
    "            for index, (token, _freq) in enumerate(counts.most_common())\n",
    "        })\n",
    "        self.index2token = {index: token for token, index in self.token2index.items()}\n",
    "\n",
    "    def __getitem__(self, query):\n",
    "        if isinstance(query, str):\n",
    "            # Unseen tokens map to <unk>, not to <pad> (id 0): the attention padding mask\n",
    "            # hides id 0, which would silently drop unknown characters from the input.\n",
    "            return self.token2index.get(query, self.token2index['<unk>'])\n",
    "        if isinstance(query, int):\n",
    "            return self.index2token.get(query, '<unk>')\n",
    "        if isinstance(query, (list, tuple)):\n",
    "            return [self[item] for item in query]\n",
    "        # Fail loudly instead of silently returning None for an unsupported index\n",
    "        raise TypeError(f\"Vocab index must be str, int, list or tuple, not {type(query).__name__}\")\n",
    "\n",
    "    def __len__(self):\n",
    "        # Size of the id space, special tokens included\n",
    "        return len(self.index2token)\n",
    "\n",
    "# Instantiate the vocabulary over every tokenized dialogue\n",
    "vocab = Vocab(tokens)\n",
    "vocab_size = len(vocab)  # size of the id space (special tokens included)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f1775af",
   "metadata": {},
   "source": [
    "### 1.4 构造数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "46c5ec30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset of id sequences for next-token prediction\n",
    "class MyDataSet(Data.Dataset):\n",
    "    \"\"\"Wrap a list of id sequences; each item is an (input, target) pair shifted by one.\"\"\"\n",
    "\n",
    "    def __init__(self, datas):\n",
    "        self.datas = datas\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        sequence = self.datas[item]\n",
    "        inputs, targets = sequence[:-1], sequence[1:]  # predict token t+1 from tokens <= t\n",
    "        return {\n",
    "            \"decoder_input\": inputs,\n",
    "            \"decoder_input_len\": len(inputs),\n",
    "            \"decoder_output\": targets,\n",
    "            \"decoder_output_len\": len(targets),\n",
    "        }\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "\n",
    "    def padding_batch(self, batch):\n",
    "        \"\"\"collate_fn: right-pad the batch with the <pad> id and return two LongTensors.\n",
    "\n",
    "        Inputs are padded to the longest input and targets to the longest target in\n",
    "        the batch. The per-sample lists are extended in place.\n",
    "        \"\"\"\n",
    "        longest_in = max(sample[\"decoder_input_len\"] for sample in batch)\n",
    "        longest_out = max(sample[\"decoder_output_len\"] for sample in batch)\n",
    "        pad_id = vocab[\"<pad>\"]\n",
    "\n",
    "        for sample in batch:\n",
    "            sample[\"decoder_input\"].extend([pad_id] * (longest_in - sample[\"decoder_input_len\"]))\n",
    "            sample[\"decoder_output\"].extend([pad_id] * (longest_out - sample[\"decoder_output_len\"]))\n",
    "\n",
    "        inputs = torch.tensor([sample[\"decoder_input\"] for sample in batch], dtype=torch.long)\n",
    "        targets = torch.tensor([sample[\"decoder_output\"] for sample in batch], dtype=torch.long)\n",
    "        return inputs, targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "630d9554",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "# Build the dataset: token sequences -> id sequences\n",
    "tokens_num = [[vocab[word] for word in line] for line in tokens]\n",
    "dataset = MyDataSet(tokens_num)\n",
    "# shuffle=True: without it every epoch iterates the batches in the same (file) order.\n",
    "# The seeded generator keeps the shuffle order reproducible across re-runs.\n",
    "data_loader = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True,\n",
    "                              generator=torch.Generator().manual_seed(42),\n",
    "                              collate_fn=dataset.padding_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b32974cf",
   "metadata": {},
   "source": [
    "## 2. 建立模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bf674cc",
   "metadata": {},
   "source": [
    "### 2.1 掩码操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "535c160a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mask out padding: key positions holding the pad id 0\n",
    "def get_attn_pad_mask(seq_q, seq_k):\n",
    "    \"\"\"Bool mask [batch_size, len_q, len_k], True where seq_k holds the pad id 0.\n",
    "\n",
    "    seq_q, seq_k: [batch_size, seq_len] id tensors; only seq_q's length is used.\n",
    "    The result is an expanded view, shared across the len_q dimension.\n",
    "    \"\"\"\n",
    "    _, len_q = seq_q.size()\n",
    "    batch_size, len_k = seq_k.size()\n",
    "    key_is_pad = (seq_k == 0)[:, None, :]  # [batch_size, 1, len_k]\n",
    "    return key_is_pad.expand(batch_size, len_q, len_k)\n",
    "\n",
    "# Mask out the future: position i may not attend to any position j > i\n",
    "def get_attn_subsequence_mask(seq):\n",
    "    \"\"\"uint8 mask [batch_size, tgt_len, tgt_len] on `device`, 1 strictly above the diagonal.\"\"\"\n",
    "    n_batch, tgt_len = seq.size(0), seq.size(1)\n",
    "    future = torch.ones(n_batch, tgt_len, tgt_len).triu(diagonal=1).byte()\n",
    "    return future.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9d8ba51",
   "metadata": {},
   "source": [
    "### 2.2 注意力计算函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e3e9e702",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scaled dot-product attention\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, Q, K, V, attn_mask):\n",
    "        '''\n",
    "        Q: [batch_size, n_heads, len_q, d_k]\n",
    "        K: [batch_size, n_heads, len_k, d_k]\n",
    "        V: [batch_size, n_heads, len_v(=len_k), d_v]\n",
    "        attn_mask: boolean, broadcastable to [batch_size, n_heads, len_q, len_k]; True = blocked\n",
    "        Returns (context [batch_size, n_heads, len_q, d_v], attn [batch_size, n_heads, len_q, len_k]).\n",
    "        '''\n",
    "        scores = Q @ K.transpose(-1, -2) / np.sqrt(d_k)  # [batch_size, n_heads, len_q, len_k]\n",
    "        # Blocked positions get a huge negative score, so softmax gives them ~0 weight\n",
    "        scores.masked_fill_(attn_mask, -1e9)\n",
    "        attn = torch.softmax(scores, dim=-1)\n",
    "        return attn @ V, attn\n",
    "\n",
    "# Multi-head attention with a residual connection and post-LayerNorm\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Registration order is unchanged: it fixes parameter-init order and state_dict order\n",
    "        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)\n",
    "        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)\n",
    "        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)\n",
    "        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)\n",
    "        self.layernorm = nn.LayerNorm(d_model)\n",
    "\n",
    "    def forward(self, input_Q, input_K, input_V, attn_mask):\n",
    "        '''\n",
    "        input_Q: [batch_size, len_q, d_model]\n",
    "        input_K: [batch_size, len_k, d_model]\n",
    "        input_V: [batch_size, len_v(=len_k), d_model]\n",
    "        attn_mask: [batch_size, len_q, len_k]\n",
    "        Returns (LayerNorm(fc(context) + input_Q) [batch_size, len_q, d_model],\n",
    "                 attn [batch_size, n_heads, len_q, len_k]).\n",
    "        '''\n",
    "        batch_size = input_Q.size(0)\n",
    "\n",
    "        def split_heads(x, proj, head_dim):\n",
    "            # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n",
    "            return proj(x).view(batch_size, -1, n_heads, head_dim).transpose(1, 2)\n",
    "\n",
    "        Q = split_heads(input_Q, self.W_Q, d_k)\n",
    "        K = split_heads(input_K, self.W_K, d_k)\n",
    "        V = split_heads(input_V, self.W_V, d_v)\n",
    "        # One mask shared by all heads; expand() is a view, so nothing is copied per head\n",
    "        head_mask = attn_mask.unsqueeze(1).expand(-1, n_heads, -1, -1)\n",
    "\n",
    "        context, attn = ScaledDotProductAttention()(Q, K, V, head_mask)\n",
    "        merged = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)  # [batch_size, len_q, n_heads * d_v]\n",
    "        return self.layernorm(self.fc(merged) + input_Q), attn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46ff28f5",
   "metadata": {},
   "source": [
    "### 2.3 构建前馈网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9136dbb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position-wise feed-forward network with a residual connection and post-LayerNorm\n",
    "class PoswiseFeedForwardNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # d_model -> d_ff -> d_model, applied independently at every position\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(d_model, d_ff, bias=False),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(d_ff, d_model, bias=False),\n",
    "        )\n",
    "        self.layernorm = nn.LayerNorm(d_model)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"inputs: [batch_size, seq_len, d_model] -> LayerNorm(fc(inputs) + inputs), same shape.\"\"\"\n",
    "        return self.layernorm(self.fc(inputs) + inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "edddc1ad",
   "metadata": {},
   "source": [
    "### 2.4 解码器模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "709d55a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decoder layer: masked multi-head self-attention followed by the feed-forward network\n",
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.dec_self_attn = MultiHeadAttention()\n",
    "        self.pos_ffn = PoswiseFeedForwardNet()\n",
    "\n",
    "    def forward(self, dec_inputs, dec_self_attn_mask):\n",
    "        '''\n",
    "        dec_inputs: [batch_size, tgt_len, d_model]\n",
    "        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]\n",
    "        Returns (outputs [batch_size, tgt_len, d_model],\n",
    "                 self-attention weights [batch_size, n_heads, tgt_len, tgt_len]).\n",
    "        '''\n",
    "        # Self-attention: queries, keys and values all come from dec_inputs\n",
    "        attended, self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)\n",
    "        return self.pos_ffn(attended), self_attn\n",
    "\n",
    "\n",
    "# Decoder stack: token + learned position embeddings, then n_layers DecoderLayers\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.tgt_emb = nn.Embedding(vocab_size, d_model)\n",
    "        self.pos_emb = nn.Embedding(seq_len, d_model)  # one learned vector per position < seq_len\n",
    "        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, dec_inputs):\n",
    "        '''\n",
    "        dec_inputs: [batch_size, tgt_len] token ids\n",
    "        Returns (outputs [batch_size, tgt_len, d_model], list of per-layer attention maps).\n",
    "\n",
    "        Raises ValueError when tgt_len exceeds the positional-embedding table; otherwise the\n",
    "        lookup would fail inside nn.Embedding (IndexError on CPU, a device-side assert on CUDA).\n",
    "        '''\n",
    "        tgt_len = dec_inputs.size(1)\n",
    "        max_len = self.pos_emb.num_embeddings\n",
    "        if tgt_len > max_len:\n",
    "            raise ValueError(f\"input length {tgt_len} exceeds the {max_len} available positions\")\n",
    "\n",
    "        # Position ids 0..tgt_len-1, shared by every sequence in the batch\n",
    "        pos = torch.arange(tgt_len, dtype=torch.long, device=device)\n",
    "        pos = pos.unsqueeze(0).expand_as(dec_inputs)  # [batch_size, tgt_len]\n",
    "\n",
    "        word_emb = self.tgt_emb(dec_inputs)  # [batch_size, tgt_len, d_model]\n",
    "        pos_emb = self.pos_emb(pos)  # [batch_size, tgt_len, d_model]\n",
    "        dec_outputs = word_emb + pos_emb\n",
    "\n",
    "        # A key position is blocked if it is padding OR lies in the future\n",
    "        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)  # bool [batch_size, tgt_len, tgt_len]\n",
    "        dec_self_attn_subsequent_mask = get_attn_subsequence_mask(dec_inputs)  # uint8 [batch_size, tgt_len, tgt_len]\n",
    "        dec_self_attn_mask = torch.gt(dec_self_attn_pad_mask + dec_self_attn_subsequent_mask, 0)  # bool\n",
    "\n",
    "        dec_self_attns = []\n",
    "        for layer in self.layers:\n",
    "            # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]\n",
    "            dec_outputs, dec_self_attn = layer(dec_outputs, dec_self_attn_mask)\n",
    "            dec_self_attns.append(dec_self_attn)\n",
    "\n",
    "        return dec_outputs, dec_self_attns"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae53c38e",
   "metadata": {},
   "source": [
    "### 2.5 GPT模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ded899f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT(nn.Module):\n",
    "    \"\"\"Decoder-only Transformer: the Decoder stack plus a linear projection to vocabulary logits.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.decoder = Decoder()\n",
    "        self.projection = nn.Linear(d_model, vocab_size, bias=False)\n",
    "\n",
    "    def forward(self, dec_inputs):\n",
    "        \"\"\"\n",
    "        dec_inputs: [batch_size, tgt_len]\n",
    "        Returns (logits flattened to [batch_size * tgt_len, vocab_size], per-layer attention maps).\n",
    "        \"\"\"\n",
    "        # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attns: n_layers x [batch_size, n_heads, tgt_len, tgt_len]\n",
    "        dec_outputs, dec_self_attns = self.decoder(dec_inputs)\n",
    "        dec_logits = self.projection(dec_outputs)  # [batch_size, tgt_len, vocab_size]\n",
    "        return dec_logits.view(-1, dec_logits.size(-1)), dec_self_attns\n",
    "\n",
    "    def answer(self, above, max_new_tokens=100):\n",
    "        \"\"\"Greedily generate a reply to the prompt text ``above``.\n",
    "\n",
    "        The prompt is encoded character by character (a tab becomes ``<sep>``, as in\n",
    "        ``tokenize``) and a ``<sep>`` is appended. Tokens are then generated greedily\n",
    "        until the model emits ``<sep>`` or ``max_new_tokens`` tokens have been added.\n",
    "        Only the most recent ``pos_emb.num_embeddings`` ids are fed to the model, so\n",
    "        long prompts or replies cannot overflow the positional-embedding table.\n",
    "\n",
    "        Returns the generated text, i.e. everything after the last ``<sep>``.\n",
    "        \"\"\"\n",
    "        sep = \"<sep>\"\n",
    "        sep_id = vocab[sep]\n",
    "        ids = [sep_id if ch == '\\t' else vocab[ch] for ch in above]\n",
    "        ids.append(sep_id)  # close the prompt utterance\n",
    "        max_ctx = self.decoder.pos_emb.num_embeddings\n",
    "\n",
    "        with torch.no_grad():  # inference only: no autograd graph needed\n",
    "            for _step in range(max_new_tokens):\n",
    "                context = torch.tensor([ids[-max_ctx:]], dtype=torch.long, device=device)\n",
    "                dec_outputs = self.decoder(context)[0]  # [1, ctx_len, d_model]\n",
    "                # Greedy choice for the last position only\n",
    "                next_id = int(self.projection(dec_outputs[0, -1]).argmax())\n",
    "                if next_id == sep_id:  # end of the reply\n",
    "                    break\n",
    "                ids.append(next_id)\n",
    "\n",
    "        text = \"\".join(vocab[i] for i in ids)  # ids -> characters\n",
    "        return text[text.rindex(sep) + len(sep):]  # keep only what follows the last <sep>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30e9e92c",
   "metadata": {},
   "source": [
    "## 3. 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fabd71ab",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:04<00:00,  3.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 4.293\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:05<00:00,  3.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.807\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:51<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.611\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.466\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:51<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.340\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.223\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:51<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.113\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 3.010\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.918\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.837\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:49<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.767\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.704\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:49<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.648\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.595\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:49<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.499\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:49<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.457\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:49<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:50<00:00,  3.64it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.378\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:51<00:00,  3.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.342\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [17:52<00:00,  3.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.308\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:05<00:00,  3.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.276\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.245\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.216\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.188\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.138\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.114\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:08<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.091\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3895/3895 [18:09<00:00,  3.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\tTrain Loss: 2.070\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "seq_len = 300  # maximum sequence length (size of the positional-embedding table)\n",
    "d_model = 768  # embedding dimension\n",
    "d_ff = 2048  # feed-forward hidden dimension\n",
    "d_k = d_v = 64  # per-head Q/K/V dimension\n",
    "n_layers = 6  # number of decoder layers\n",
    "n_heads = 8  # number of attention heads\n",
    "batch_size = 64  # NOTE: has no effect here; data_loader was built earlier with its own batch_size\n",
    "epochs = 30\n",
    "seed = 42\n",
    "save_path = 'model/gpt_chat.pt'\n",
    "\n",
    "torch.manual_seed(seed)  # reproducible weight initialisation\n",
    "\n",
    "# Model, loss (padding id 0 is ignored) and optimizer\n",
    "model = GPT().to(device)\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)  # torch.save does not create directories\n",
    "\n",
    "loss_history = []  # mean training loss per epoch\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    epoch_loss = 0\n",
    "    for dec_inputs, dec_outputs in tqdm(data_loader):\n",
    "        optimizer.zero_grad()\n",
    "        dec_inputs, dec_outputs = dec_inputs.to(device), dec_outputs.to(device)\n",
    "        # outputs: [batch_size * tgt_len, vocab_size]\n",
    "        outputs, dec_self_attns = model(dec_inputs)\n",
    "        loss = criterion(outputs, dec_outputs.view(-1))\n",
    "        epoch_loss += loss.item()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    train_loss = epoch_loss / len(data_loader)\n",
    "    loss_history.append(train_loss)\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f}')\n",
    "    torch.save(model.state_dict(), save_path)  # checkpoint after every epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6b5eb402",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAGdCAYAAADwjmIIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABAhklEQVR4nO3dd3QVdeL+8WfSSSchlQRI6L0kKEGKgoBgwdV11bWgrrrYFdx1QfenW3F3v2vBAqtiL7gaUFxFwZKAAkIgQOgtkJBCCKQRSL3z+yMSjbQbuMnce/N+nXPPIXMnycOcOdyHmc98PoZpmqYAAADcjIfVAQAAAFoCJQcAALglSg4AAHBLlBwAAOCWKDkAAMAtUXIAAIBbouQAAAC3RMkBAABuycvqAK3NZrMpPz9fQUFBMgzD6jgAAMAOpmmqoqJCsbGx8vCw7xpNmys5+fn5io+PtzoGAAA4C7m5uYqLi7Nr3zZXcoKCgiQ1HKTg4GCL0wAAAHuUl5crPj6+8XPcHm2u5By/RRUcHEzJAQDAxTRnqAkDjwEAgFui5AAAALdEyQEAAG6JkgMAANwSJQcAALglSg4AAHBLlBwAAOCWKDkAAMAtUXIAAIBbouQAAAC3RMkBAABuiZIDAADcUptboLOlFFVU6d3vc1RSWaM/Te5ndRwAANo8ruQ4SF29qWe+3Km3v89RVW291XEAAGjzKDkOEhPip4ggX9XbTG3OL7M6DgAAbR4lx0EMw9DAuFBJ0vpcSg4AAFaj5DjQoPgQSdKG3FJrgwAAAEqOIw2Kby9J2rC/1NogAACAkuNI/eMaruTsO3RUJZU1FqcBAKBto+Q4UEg7byVGBEjiag4AAFaj5DjYoMbBx6WW5gAAoK2j5DjYwPhQSQw+BgDAapQcB2ssOfvLZJqmtWEAAGjDKDkO1jsmSN6ehg5X1mh/yTGr4wAA0GZRchzM18tTfWKCJTEuBwAAK1FyWsAgxuUAAGA5Sk4L+HFcTqmlOQAAaMsoOS3geMnJyitTbb3N2jAAALRRlJwWkBAeoCA/L1XV2rTjQIXVcQAAaJMoOS3Aw+PHFck3sCI5AACWoOS0kIGsSA4AgKUoOS2k8UoOg48BALAEJaeFHH+MfMeBClVW11kbBgCANoiS00Iig/0UG+InmyltymNcDgAArY2S04KYLwcAAOtQclrQ8ZLD8g4AALQ+Sk4L4jFyAACsQ8lpQf3jQmQYUl7pMRVVVFkdBwCANoWS04ICfb3UPTJQkrSRqzkAALQqSk4LY74cAACsQclpYYM6hUpi8DEAAK2NktPCfhx8XCrTNK0NAwBAG0LJaWE9o4Pk6+Wh8qo6ZRdXWh0HAIA2g5LTwrw9PdSv4w+LdTIuBwCAVkPJaQXMlwMAQOuj5LSCgfENV3IYfAwAQOuh5LSC4yuSb8kvV02dzdowAAC0EZScVtApzF/t/b1VU2/TtsJyq+MAANAmUHJagWEYP65Izi0rAABaBSWnlRwffJxJyQEAoFVQclrJIK7kAADQqig5rWRAXMMTVrsPVqq8qtbiNAAAuD9KTisJD/RVfFg7SVLWfubLAQCgpVFyWtHxcTnMlwMAQMtzmpIza9YsGYahBx988LT7paenKykpSX5+fkpMTNTcuXNbJ6ADMC4HAIDW4xQlZ82aNXrppZc0YMCA0+6XnZ2tSZMmaeTIkcrMzNTMmTN1//33KzU1tZWSnpvGksMaVgAAtDjLS86RI0d0ww036OWXX1b79u1Pu+/cuXPVqVMnPfPMM+rdu7duv/123Xbbbfq///u/Vkp7bvrGhsjTw9CB8moVllVZHQcAALdmecm55557dOmll+riiy8+474rV67U+PHjm2ybMGGCMjIyVFt78ieWqqurVV5e3uRllXY+nuoZFSRJWp9bYlkOAADaAktLzvz587Vu3TrNmjXLrv0LCwsVFRXVZFtUVJTq6upUXFx80u+ZNWuWQkJCGl/x8fHn
nPtcHJ/5eD0rkgMA0KIsKzm5ubl64IEH9Pbbb8vPz8/u7zMMo8nXpmmedPtxM2bMUFlZWeMrNzf37EM7wKAfViRn8DEAAC3Ly6pfvHbtWhUVFSkpKalxW319vZYtW6bnn39e1dXV8vT0bPI90dHRKiwsbLKtqKhIXl5eCg8PP+nv8fX1la+vr+P/Amfp+JWcrLwy1dtMeXqcvJwBAIBzY1nJGTt2rLKysppsu/XWW9WrVy898sgjJxQcSUpJSdEnn3zSZNuSJUuUnJwsb2/vFs3rKN0jg+Tv46kj1XXac/CIuv8wRgcAADiWZbergoKC1K9fvyavgIAAhYeHq1+/fpIabjXdfPPNjd8zdepU7du3T9OmTdPWrVv16quvat68eXr44Yet+ms0m6eHoX4dG25ZMSkgAAAtx/Knq06noKBAOTk5jV8nJCTos88+U1pamgYNGqS//OUvmj17tq6++moLUzbfYObLAQCgxVl2u+pk0tLSmnz9+uuvn7DP6NGjtW7dutYJ1EJ+fMKq1NIcAAC4M6e+kuOujpecbQUVqqqttzYMAABuipJjgdgQP3UI9FWdzdTmfOsmJwQAwJ1RcixgGAbz5QAA0MIoORYZGBcqicHHAAC0FEqORY6Py+FKDgAALYOSY5HjV3L2Hjqq0qM11oYBAMANUXIsEuLvrcQOAZJ4lBwAgJZAybHQj7esWJEcAABHo+RYaGDcD09YMfgYAACHo+RY6KeDj03TtDYMAABuhpJjod4xwfL2NHSoskb7S45ZHQcAALdCybGQn7enescES+KWFQAAjkbJsVjjpIA8YQUAgENRcizGE1YAALQMSo7FBv1QcrLyylRXb7M2DAAAboSSY7HEDgEK8vXSsdp67ThwxOo4AAC4DUqOxTw8DA2IZ74cAAAcjZLjBBh8DACA41FynMDxwcesYQUAgONQcpzA8cHHOw5U6GhNnbVhAABwE5QcJxAV7KfoYD/ZTGlTXrnVcQAAcAuUHCdx/GrOmr2HrQ0CAICboOQ4iVE9IiRJb6zYq6raeovTAADg+ig5TuLqpI7qGNpORRXVenPlXqvjAADg8ig5TsLXy1MPXtxdkvRi2m5VVNVanAgAANdGyXEiVw2JU7fIQJUerdXLy7OtjgMAgEuj5DgRTw9D08f1kCTNW75Hh45UW5wIAADXRclxMpf0i1b/jiGqrKnXnLTdVscBAMBlUXKcjGEYenhCT0nSm6v2Kb/0mMWJAABwTZQcJzSqewednxCmmjqbnvt6p9VxAABwSZQcJ2QYhn73w9Wc/2bsV3ZxpcWJAABwPZQcJ5XcJUxjekWq3mbq6aU7rI4DAIDLoeQ4sYfHN1zNWbQhX1vyWdMKAIDmoOQ4sT6xwbp8YKwk6d9LtlucBgAA10LJcXIPXdxdnh6GvtpWpLX7WLwTAAB7UXKcXGJEoK5JipMk/fPz7TJN0+JEAAC4BkqOC7h/bHf5eHno++zDWr6z2Oo4AAC4BEqOC4gNbaebhnWWJP3rC67mAABgD0qOi7j7wq4K8PFUVl6ZvthcaHUcAACcHiXHRYQH+uo3IxIkSf+3ZIfqbVzNAQDgdCg5LuT2UYkK9ffWrqIjWpiZZ3UcAACcGiXHhQT7eeuu0V0lSU8v3aHqunqLEwEA4LwoOS7m5pQuigzyVV7pMb2/JtfqOAAAOC1Kjotp5+Op+8Z2lyTN/mqXjtbUWZwIAADnRMlxQdcmx6tTmL+Kj1Tr9RV7rY4DAIBTouS4IB8vDz00ruFqzty03So7VmtxIgAAnA8lx0VdMbCjekQFqryqTi8v22N1HAAAnA4lx0V5ehiaPr6nJOnV77J1sKLa4kQAADgXSo4LG98nSgPjQ3W0pl4vfLPL6jgAADgVSo4LMwxDv5/QcDXn3e9ztL/kqMWJAABwHpQcF3dBtw4a3jVcNfU2zf5qp9VxAABwGpQcN/DwD1dzPly7Xxl7D1ucBgAA50DJcQNDOrXXVUM6ymZKD76/XhVV
PFIOAAAlx0386Yq+ig9rp/0lx/T4x5utjgMAgOUoOW4iyM9bT/9qkDwMaUFmnhZtyLc6EgAAlqLkuJHkLmG6d0zDTMiPLsxSXukxixMBAGAdSo6buX9MNw3uFKqKqjo99P561dtMqyMBAGAJSo6b8fL00DPXDlKAj6dWZx/Wf5bttjoSAACWoOS4oc7hAXriir6SpKeW7NDG/aXWBgIAwAKUHDf1y6Q4TeofrTqbqQfnr9fRmjqrIwEA0KooOW7KMAz9/Rf9FR3spz3FlfrL/7ZaHQkAgFZFyXFjof4+euragTIM6b3VOfpic6HVkQAAaDWUHDc3vGsH3TkyUZL0h9SNKiqvsjgRAACtg5LTBkwb30N9Y4NVcrRW0z/YIBuPlQMA2gBLS86cOXM0YMAABQcHKzg4WCkpKVq8ePEp909LS5NhGCe8tm3b1oqpXY+vl6eevW6Q/Lw9tHxnsV5bsdfqSAAAtDhLS05cXJyefPJJZWRkKCMjQ2PGjNHkyZO1efPp117avn27CgoKGl/du3dvpcSuq1tkkB69tI8k6R+Lt2lrQbnFiQAAaFmWlpzLL79ckyZNUo8ePdSjRw/97W9/U2BgoFatWnXa74uMjFR0dHTjy9PTs5USu7Ybz++ksb0iVVNv04Pz16uqtt7qSAAAtBinGZNTX1+v+fPnq7KyUikpKafdd/DgwYqJidHYsWP1zTffnHbf6upqlZeXN3m1VYZh6B+/HKAOgb7afqBCTy7mNh8AwH1ZXnKysrIUGBgoX19fTZ06VQsXLlSfPn1Oum9MTIxeeuklpaamasGCBerZs6fGjh2rZcuWnfLnz5o1SyEhIY2v+Pj4lvqruIQOgb761zUDJEmvr9irtO1FFicCAKBlGKZpWvqoTU1NjXJyclRaWqrU1FS98sorSk9PP2XR+bnLL79chmFo0aJFJ32/urpa1dXVjV+Xl5crPj5eZWVlCg4OdsjfwRU9sWizXl+xVx0CffXFgyMVHuhrdSQAAE6pvLxcISEhzfr8tvxKjo+Pj7p166bk5GTNmjVLAwcO1LPPPmv39w8bNkw7d+485fu+vr6NT28df0H6w8Re6hEVqOIj1XokdaMs7roAADic5SXn50zTbHLl5UwyMzMVExPTgonck5+3p565drB8PD305dYivbs6x+pIAAA4lJeVv3zmzJmaOHGi4uPjVVFRofnz5ystLU2ff/65JGnGjBnKy8vTm2++KUl65pln1KVLF/Xt21c1NTV6++23lZqaqtTUVCv/Gi6rT2ywfn9JT/310636y/+26PyEcHWLDLQ6FgAADmFpyTlw4IBuuukmFRQUKCQkRAMGDNDnn3+ucePGSZIKCgqUk/PjFYaamho9/PDDysvLU7t27dS3b199+umnmjRpklV/BZd32wUJStt+UN/uKtZdb6/VgruHK8jP2+pYAACcM8sHHre2sxm45O6Kyqt0+fPf6kB5tS7uHamXbkqWh4dhdSwAABq55MBjWC8y2E8v3ZQsH6+G8TlPLd1hdSQAAM4ZJQeSpIHxofrH1f0lSc9/s0ufbMi3OBEAAOeGkoNGvxgcpztHJUqSfvfhBm3KK7M4EQAAZ4+SgyYeuaSXRveIUFWtTXe+maGDFfY/zg8AgDOh5KAJTw9Ds68frMQOAcovq9Jdb69VTZ3N6lgAADQbJQcnCGnnrZenJCvI10sZ+0r0+KJNzIgMAHA5lBycVNeIQM2+frAMQ3pvda7eWrXP6kgAADQLJQendFGvSD1ySS9J0p8+2aIVu4stTgQAgP0oOTit345K1JWDYlVvM3XPO+uUe/io1ZEAALALJQenZRiGnrx6gAbEhajkaK3ueDNDldV1VscCAOCMKDk4Iz9vT/3npiR1CPTVtsIKTfvvetlsDEQGADg3Sg7sEhPSTv+5KUk+nh76YvMBPfvVTqsjAQBwWs0uOZ9//rm+/fbbxq9feOEFDRo0SL/+9a9VUlLi0HBwLkmd2+uvv+gnSXr2
q51anFVgcSIAAE6t2SXnd7/7ncrLyyVJWVlZmj59uiZNmqQ9e/Zo2rRpDg8I5/Kr5HjddkGCJGnafzdoa0G5xYkAADi5Zpec7Oxs9enTR5KUmpqqyy67TH//+9/14osvavHixQ4PCOczc1IvjezeQcdq63XHmxk6XFljdSQAAE7Q7JLj4+Ojo0cbHiP+8ssvNX78eElSWFhY4xUeuDcvTw89d/1gdQ731/6SY7r7nbWqrWfpBwCAc2l2yRkxYoSmTZumv/zlL1q9erUuvfRSSdKOHTsUFxfn8IBwTqH+Pnrl5mQF+npp1Z7D+vMnW6yOBABAE80uOc8//7y8vLz04Ycfas6cOerYsaMkafHixbrkkkscHhDOq3tUkJ65dpAMQ3pr1T7N+zbb6kgAADQyzDa28mJ5eblCQkJUVlam4OBgq+O4hbnpu/Xk4m0yDOmFXw/RpP4xVkcCALiZs/n8bvaVnHXr1ikrK6vx648//lhXXnmlZs6cqZoaBqC2Rb8dlaibUzrLNKUH31+vNXsPWx0JAIDml5zf/va32rFjhyRpz549uu666+Tv768PPvhAv//97x0eEM7PMAw9fnlfjesTpZo6m25/I0O7io5YHQsA0MY1u+Ts2LFDgwYNkiR98MEHGjVqlN599129/vrrSk1NdXQ+uAhPD0OzrxuswZ1CVXasVlNeXa2iiiqrYwEA2rBmlxzTNGWzNTwu/OWXX2rSpEmSpPj4eBUXFzs2HVxKOx9PvXJzsrqE+yuv9Jhue30Ni3kCACzT7JKTnJysv/71r3rrrbeUnp7e+Ah5dna2oqKiHB4QriU80Fdv3HaewgN8tCmvXHe/s445dAAAlmh2yXnmmWe0bt063XvvvXr00UfVrVs3SdKHH36o4cOHOzwgXE/n8ADNu2Wo/Lw9lL7joB5buElt7CE+AIATcNgj5FVVVfL09JS3t7cjflyL4RHy1vPllgO6860M2UzpoYt76IGLu1sdCQDgos7m89vrbH/Z2rVrtXXrVhmGod69e2vIkCFn+6Pgpi7uE6W/XNlPjy7cpKe/3KGYUD/9Kjne6lgAgDai2SWnqKhI1157rdLT0xUaGirTNFVWVqaLLrpI8+fPV0REREvkhIu64fzOyi89phe+2a0ZC7IUFeyn0T04RwAALa/ZY3Luu+8+VVRUaPPmzTp8+LBKSkq0adMmlZeX6/7772+JjHBxD4/vqV8M7qh6m6m7316rTXllVkcCALQBzR6TExISoi+//FJDhw5tsn316tUaP368SktLHZnP4RiTY42aOptufX21vtt1SBFBvlp493DFtfe3OhYAwEW0yrIONpvtpIOLvb29G+fPAX7Ox8tDc25MUq/oIB2sqNYtr61R6VGWAQEAtJxml5wxY8bogQceUH5+fuO2vLw8PfTQQxo7dqxDw8G9BPt567VbhyomxE+7io7ozjfXqqq23upYAAA31eyS8/zzz6uiokJdunRR165d1a1bNyUkJKiiokLPPfdcS2SEG4kJaafXbz1PQX5eWr33sKb/d4NsNubQAQA43lnPk7N06VJt27ZNpmmqT58+uvjiix2drUUwJsc5rNhdrCmvrlZtvanbRyToscv6WB0JAODEzubz22GTAboKSo7z+Hh9nh6Yv16S9NilvXX7yERrAwEAnFaLTQY4e/Zsu0PwGDnsNXlQRxWUVenJxdv010+3KqSdt65hskAAgIPYdSUnISHBvh9mGNqzZ885h2pJXMlxLqZp6m+fbtUr32bLw5BevCFJl/SLtjoWAMDJcLvKDpQc52Oaph5J3aj/ZuyXj6eHXr1lqEZ072B1LACAE2mVeXIARzMMQ7OuGqCJ/aJVU2/TnW9laF1OidWxAAAujpIDp+DpYeiZ6wZpZPcOOlpTr1teXa1theVWxwIAuDBKDpyGr5en/nNTkpI6t1d5VZ1umrdae4srrY4FAHBRlBw4FX8fL706ZWjj8g83zvtehWVVVscCALggSg6cToi/
t976zfnqEu6v/SXHdOO873W4knWuAADNc1ZPV5WWlmr16tUqKio6YVHOm2++2WHhWgJPV7mO/SVH9cs5K1VYXqUBcSF65/bzFeR34uKwAAD31yqPkH/yySe64YYbVFlZqaCgIBmG8eMPMwwdPny4ealbGSXHtewqqtCv/rNKhytrNCwxTK/fep78vD2tjgUAaGWt8gj59OnTddttt6miokKlpaUqKSlpfDl7wYHr6RYZpDduPU+Bvl5ateew7n13nWrrbWf+RgBAm9fskpOXl6f7779f/v7+LZEHOEH/uBDNm5IsXy8Pfbm1SA9/wMrlAIAza3bJmTBhgjIyMloiC3BK5yeGa86NQ+TlYejj9fl6fNFmtbHJugEAzWTXAp0/demll+p3v/udtmzZov79+8vbu+lA0CuuuMJh4YCfGtMrSv/+1UA9+P56vbVqn0LaeevhCT2tjgUAcFLNHnjs4XHqiz+GYai+vv6cQ7UkBh67vne+36dHF26SJM2c1Et3jupqcSIAQEtrlYHHNpvtlC9nLzhwDzec31m/v6ThCs7fP9um+atzLE4EAHBGTAYIl3T3hd3029GJkqQZC7O0MHO/xYkAAM7GrjE5s2fP1p133ik/Pz/Nnj37tPvef//9DgkGnMkfLumlI1V1euf7HE3/7wZJ0i8Gx1mcCgDgLOwak5OQkKCMjAyFh4crISHh1D/MMLRnzx6HBnQ0xuS4F5vN1GMfb9K73+fIMKR/XzNQVw2h6ACAuzmbz2+7ruRkZ2ef9M+A1Tw8DP11cj8ZUsMVnQ82yDSlq5MoOgDQ1jEmBy7Pw8PQXyb3043DOsk0pYc/3KAP1zJGBwDaumbPkyNJ+/fv16JFi5STk6OamqarQz/11FMOCQY0x/GiI0lvr8rR7z7cIJtp6lfJ8RYnAwBYpdkl56uvvtIVV1yhhIQEbd++Xf369dPevXtlmqaGDBnSEhkBuxhGQ9ExZOitVfv0SOpGyZR+NZSiAwBtUbNvV82YMUPTp0/Xpk2b5Ofnp9TUVOXm5mr06NG65pprWiIjYDfDMPTnyX11c0pnmab0yIKN+u+aXKtjAQAs0OySs3XrVk2ZMkWS5OXlpWPHjikwMFB//vOf9Y9//MPhAYHmMgxDf7qir6b8UHR+n7pR769hwkAAaGuaXXICAgJUXV0tSYqNjdXu3bsb3ysuLnZcMuAcGIahJ67oq1uGd5EkPZKaxczIANDGNHtMzrBhw/Tdd9+pT58+uvTSSzV9+nRlZWVpwYIFGjZsWEtkBM6KYRh6/PI+kqTXV+zVHxZkyZR0/XmdrA0GAGgVzS45Tz31lI4cOSJJeuKJJ3TkyBG9//776tatm55++mmHBwTOxfGiYxjSa9/t1YwFWTJN6dfnU3QAwN01q+TU19crNzdXAwYMkCT5+/vrxRdfbJFggKMYhqH/d1kfGTL06nfZmrkwS6ZM3XB+Z6ujAQBaULPG5Hh6emrChAkqLS1toThAyzAMQ3+8rLd+M6JhWZJHF27S26v2WZwKANCSmj3wuH///g5bn2rOnDkaMGCAgoODFRwcrJSUFC1evPi035Oenq6kpCT5+fkpMTFRc+fOdUgWuD/DMPTYpb11+w9F57GPNuktig4AuK1ml5y//e1vevjhh/W///1PBQUFKi8vb/Jqjri4OD355JPKyMhQRkaGxowZo8mTJ2vz5s0n3T87O1uTJk3SyJEjlZmZqZkzZ+r+++9Xampqc/8aaKMMw9Cjl/bWHSMbis4fP9qkN1bstTYUAKBF2LUK+U95ePzYiwzDaPyzaZoyDEP19fXnFCgsLEz/+te/9Jvf/OaE9x555BEtWrRIW7dubdw2depUbdiwQStXrrTr57MKOaSG83XW4m16aVnDVcn7x3TTQ+N6NDmnAQDOo8VWIf+pb775ptnB7FFfX68PPvhAlZWVSklJOek+K1eu1Pjx45tsmzBhgubNm6fa2lp5e3uf8D3V1dWN8/pIavbVJrgnwzA0
Y2Iv+ft46pkvd2r217tUVFGtv17ZT16erFsLAO6g2SUnISFB8fHxJ/yP1zRN5eY2f/r8rKwspaSkqKqqSoGBgVq4cKH69Olz0n0LCwsVFRXVZFtUVJTq6upUXFysmJiYE75n1qxZ+tOf/tTsXHB/hmHowYt7KDLIT499lKX5a3JVfKRaz10/RO18PK2OBwA4R83+L2tCQoIOHjx4wvbDhw8rISGh2QF69uyp9evXa9WqVbrrrrs0ZcoUbdmy5ZT7n6xcnWz7cTNmzFBZWVnj62yKGNzbr8/vpDk3JsnXy0Nfbi3SDa+sUklljdWxAADnqNkl5/jYm587cuSI/Pz8mh3Ax8dH3bp1U3JysmbNmqWBAwfq2WefPem+0dHRKiwsbLKtqKhIXl5eCg8PP+n3+Pr6Nj69dfwF/NyEvtF6+/bzFeznpXU5pbrmPyuVV3rM6lgAgHNg9+2qadOmSfphvpE//lH+/v6N79XX1+v777/XoEGDzjmQaZpNxtD8VEpKij755JMm25YsWaLk5OSTjscBmmNolzB9eNdwTXl1tXYVHdFVL36nN247T72iKcYA4IrsLjmZmZmSGkpIVlaWfHx8Gt/z8fHRwIED9fDDDzfrl8+cOVMTJ05UfHy8KioqNH/+fKWlpenzzz+X1HCrKS8vT2+++aakhiepnn/+eU2bNk133HGHVq5cqXnz5um9995r1u8FTqVHVJBSfyg6O4uO6Jq5K/XKzck6P/HkVwoBAM7L7pJz/KmqW2+9Vc8++6xDbvscOHBAN910kwoKChQSEqIBAwbo888/17hx4yRJBQUFysn5ceXohIQEffbZZ3rooYf0wgsvKDY2VrNnz9bVV199zlmA42JD2+nDqcN1+5trtGZviW56dbWevXaQJvY/cWA7AMB5NXueHFfHPDmwV1Vtve5/L1NLthyQYUh/vqKvbkrpYnUsAGiTzubzmwlBgFPw8/bUnBuT9OvzO8k0pT9+vFn/98V2tbH/FwCAy6LkAKfh6WHob1f200MX95AkPf/NLj2SulF19TaLkwEAzoSSA5yBYRh64OLumnVVf3kY0n8z9uu3b63VsZpzW8IEANCyKDmAna4/r5Pm/jBp4FfbivRrJg0EAKdGyQGaYXzfaL1z+/kKaeetzJxSXT13hfYdqrQ6FgDgJCg5QDMldwnTh1NTFBvipz0HK3X5c9/qm21FVscCAPwMJQc4C92jgrTwngs0pFOoyqvqdNsba/Tslztls/HkFQA4C0oOcJaigv00/84U3Tis4RHzp7/coTvezFDZsVqrowEARMkBzomPl4f+emV//euXA+Tzw4Dkyc9/q+2FFVZHA4A2j5IDOMA1yfFKnTpcHUPbae+ho7ryhe/0yYZ8q2MBQJtGyQEcpH9ciD65b4Qu6BauY7X1uu+9TP3t0y1MHAgAFqHkAA4UFuCjN249T1NHd5Ukvbw8WzfO+17FR6otTgYAbQ8lB3AwL08P/WFiL825YYgCfDy1as9hXf7ct1qfW2p1NABoUyg5QAuZ2D9GH91zgRI7BKigrEq/mrtS763OsToWALQZlBygBXWPCtLH916gcX2iVFNv04wFWfpD6kZV17HuFQC0NEoO0MKC/Lz1nxuT9LsJPWUY0vw1ufrV3JXKLz1mdTQAcGuUHKAVeHgYuueibnr91vMU0s5bG/aX6fLnvtV3u4qtjgYAbouSA7Si0T0i9L/7RqhPTLAOVdboxnnf6y//26KqWm5fAYCjUXKAVhYf5q/Uu4br+vPiZZrSvG+zdens5drA01cA4FCUHMAC7Xw8NeuqAXr1lmRFBPlq98FKXTVnhf69ZLtq6pg8EAAcgZIDWGhMrygteXCULh8Yq3qbqee+3qUrX/hO2wrLrY4GAC6PkgNYrH2Aj567frCe//Vgtff31paCcl3x3Heak7Zb9TbT6ngA4LIoOYCTuGxArL54aJTG9opUTb1N//h8m66Zu0LZxZVWRwMAl0TJAZxIZJCfXpmS
rH/+coACfb20LqdUE59dpjdW7JWNqzoA0CyUHMDJGIahXyXH6/MHRyolMVxVtTY9vmizbnr1e+UxgSAA2I2SAzipuPb+euf28/XE5X3k5+2h73Yd0iVPL9MHGbkyTa7qAMCZUHIAJ+bhYeiWCxL02f0jNbhTqCqq6/S7DzfqjjfXqqiiyup4AODUKDmAC0iMCNQHv03R7y/pKW9PQ19uPaAJTy9T6tr9XNUBgFOg5AAuwsvTQ3df2E2L7h2h3jHBKjlaq+kfbNC1L63S9sIKq+MBgNOh5AAupndMsD6+5wL9/pKeauftqdXZhzVp9nL97dMtOlJdZ3U8AHAalBzABfl4NVzVWTptlMb3iVK9zdTLy7N18b/T9enGAm5hAYAoOYBLi2vvr5duTtZrtwxVpzB/FZZX6Z531+nmV1drz8EjVscDAEtRcgA3cFGvSC15aJQeGNtdPl4eWr6zWJc8s1z/XrJdVbX1VscDAEtQcgA34eftqYfG9dCSB0dpVI8I1dTb9NzXuzTu6XR9tfWA1fEAoNVRcgA306VDgN64dajm3DBEMSF+yj18TL95I0O3v5Gh3MNHrY4HAK2GkgO4IcMwNLF/jL6cNlq/HZUoL4+GuXXGPZ2uF77Zpeo6bmEBcH+G2cYewygvL1dISIjKysoUHBxsdRygVew4UKE/frRJ32cfliQlRgToT1f01cjuERYnAwD7nM3nNyUHaCNM09TH6/P110+3qvhItSTpwp4RmjGxt3pGB1mcDgBOj5JjB0oO2rqyY7V65ssdemvlPtXZTHkY0i+T4jRtXE9Fh/hZHQ8AToqSYwdKDtAgu7hS//pimz7LKpQk+Xl76I6Rifrt6K4K9PWyOB0ANEXJsQMlB2hq7b4S/f2zrVq7r0SS1CHQRw9c3EPXDY2XtyfPJgBwDpQcO1BygBOZpqkvNhfqH59vV3ZxpaSGwcl/uKSXxvWJkmEYFicE0NZRcuxAyQFOrbbepne/z9GzX+3U4coaSdJ5XcI089LeGhQfam04AG0aJccOlBzgzMqrajU3bbfmfZut6jqbJOmyATH6/YRe6hTub3E6AG0RJccOlBzAfgVlx/TvJTuUum6/TFPy9jR007Auum9MN7UP8LE6HoA2hJJjB0oO0Hxb8ss1a/FWLd9ZLEkK8vPS1NFddcvwLgrgSSwArYCSYwdKDnD2lu04qL9/tlXbCiskSWEBPpo6OlE3Deuidj6eFqcD4M4oOXag5ADnpt5m6pMN+Xrmyx3ae6hhwc+IIF/dfWFXXX9eJ/l5U3YAOB4lxw6UHMAx6uptWpCZp9lf7dT+kmOSpOhgP90zppuuTY6Xjxdz7ABwHEqOHSg5gGPV1Nn0wdpcPf/1LhWUVUmSOoa20/1ju+mqIXFMKAjAISg5dqDkAC2jqrZe81fn6IW03TpY0bAAaOdwfz0wtrsmD+ooTw8mFARw9ig5dqDkAC3rWE293vl+n+ak7dahHyYU7BoRoAcv7qFL+8fIg7ID4CxQcuxAyQFaR2V1nd5YuVf/Sd+jsmO1kqSeUUF6aFx3TegbzVIRAJqFkmMHSg7QuiqqavXqt3v1yvI9qqiukyT1iQnWvWO6aULfaG5jAbALJccOlBzAGmVHa/Xy8j167btsVdbUS5ISOwRo6oVddeWgjjyNBeC0KDl2oOQA1jpcWaPXV+zV699lq7yq4cpObIif7hyVqGuHdmJSQQAnRcmxAyUHcA5Hquv0zqp9euXb7ManscIDfHTbiATdlNJZwX7eFicE4EwoOXag5ADOpaq2Xh+u3a+56bsbJxUM8vXSzcM769YLEtQh0NfihACcASXHDpQcwDnV1dv0ycZ8zUnbrR0HjkiS/Lw9dN3QTrpjVKI6hrazOCEAK1Fy7EDJAZybzWbqy60H9ELabm3ILZUkeXkY+sXgjpp6YVd1jQi0NiAAS1By7EDJAVyDaZpasfuQXvhml1bsPiRJMgxpYr9o3TW6m/rHhVicEEBrouTYgZID
uJ7MnBK9mLZbS7ccaNw2LDFMd4xM1EU9I5lFGWgDKDl2oOQArmt7YYX+k75bizbkq87W8E9X14gA3TEyUVcO7ig/bx4/B9wVJccOlBzA9RWUHdPrK/bq3VU5jbModwj00c0pXXTjsM4KC/CxOCEARzubz29LpxidNWuWhg4dqqCgIEVGRurKK6/U9u3bT/s9aWlpMgzjhNe2bdtaKTUAq8WEtNOMib21YsYYPXZpb3UMbafiIzV6aukODX/yKz32UZayiyutjgnAYpZeybnkkkt03XXXaejQoaqrq9Ojjz6qrKwsbdmyRQEBASf9nrS0NF100UXavn17kyYXEREhT88zX6rmSg7gfurqbfpsU6FeWrZbm/LKJTUMUh7fJ0p3jkpUUucwixMCOFcuf7vq4MGDioyMVHp6ukaNGnXSfY6XnJKSEoWGhjb7d1ByAPdlmqZW7Tmsl5fv0dfbihq3D+4UqjtHJmo8C4ICLutsPr+9WjhTs5SVlUmSwsLO/L+uwYMHq6qqSn369NFjjz2miy666KT7VVdXq7q6uvHr8vJyx4QF4HQMw1BK13CldA3XrqIKvbI8WwvW5Skzp1R3vbNOncL8ddsFXXR1UpyCWDYCcHtOcyXHNE1NnjxZJSUlWr58+Sn32759u5YtW6akpCRVV1frrbfe0ty5c5WWlnbSqz9PPPGE/vSnP52wnSs5QNtQVFGlt1bu01ur9qn0aK0kyd/HU78Y3FE3Duus3jH8OwC4Ape+XXXPPffo008/1bfffqu4uLhmfe/ll18uwzC0aNGiE9472ZWc+Ph4Sg7QxhytqdOHa/frjRV7tfvgj4OSh3ZprxuHddYl/aLl68Uj6ICzctnbVffdd58WLVqkZcuWNbvgSNKwYcP09ttvn/Q9X19f+fqywB/Q1vn7eOnmlC66aVhnrdxzSG+v2qclmw9ozd4SrdlbovAAH107NF6/Pr+T4tr7Wx0XgANYWnJM09R9992nhQsXKi0tTQkJCWf1czIzMxUTE+PgdADckWEYGt61g4Z37aAD5VWavzpX767epwPl1Xoxbbfmpu/WmF6RunFYZ43qHsFsyoALs/R21d133613331XH3/8sXr27Nm4PSQkRO3aNaw4PGPGDOXl5enNN9+UJD3zzDPq0qWL+vbtq5qaGr399tt68sknlZqaqquuuuqMv5OnqwD8XG29TV9tPaC3Vu3Td7sONW7vFOavG4d10jVJ8WrPBIOApVxuTI5hnPx/SK+99ppuueUWSdItt9yivXv3Ki0tTZL0z3/+Uy+99JLy8vLUrl079e3bVzNmzNCkSZPs+p2UHACns/vgEb2zKkcfrM1VRVXDbMo+Xh66bECMbhrWWYPiQ0/5bxeAluNyJccKlBwA9jhaU6dPNuTrzZX7tDn/x6kn+sQE6/rz4jV5cEcF8xg60GooOXag5ABoDtM0tT63VG+vytEnG/NVU2eTJPl5e+jS/rG6/rx4JXVuz9UdoIVRcuxAyQFwtkqP1mjBujzNX5OjHQeONG7vFhmo64bG66ohcSwOCrQQSo4dKDkAzpVpmlqXU6r5q3P0v40FOlZbL0ny8fTQ+L5Ruv68TkpJDOfJLMCBKDl2oOQAcKSKqlp9vD5f89fkNC4OKkmdw/31q+R4XZMUp8hgPwsTAu6BkmMHSg6AlrIpr0zvrc7Rx+vzdaS64cksTw9DY3tF6vrzOmlUjwgWCAXOEiXHDpQcAC3taE2d/rexQPNX52hdTmnj9pgQP/1icEddnRSnrhGB1gUEXBAlxw6UHACtaceBCr23OkcLM/MaFwiVpEHxobo6KU6XD4hRqD+DlYEzoeTYgZIDwApVtfX6cusBLViXp/QdB1Vva/in18fTQ2N7R+rqIXEa3TNC3p4eFicFnBMlxw6UHABWK6qo0qL1+fpw7X5tK6xo3B4e4KMrBsXq6iFx6hsbzNw7wE9QcuxAyQHgTDbnl2nBujx9vD5PxUdqGrf3ig7S1UPiNHlQ
LE9nAaLk2IWSA8AZ1dbbtGzHQS1Yl6elWw6opr5hZmUPQxrVI0JXD4nTuD5R8vP2tDgpYA1Kjh0oOQCcXdnRWn2yMV+p6/Yr8ydPZwX5eumSftG6cnBHDUsM53F0tCmUHDtQcgC4kt0Hj2jhujwtzMxTXumxxu1Rwb66fECsrhzckfE7aBMoOXag5ABwRTabqTV7D+uj9fn6LKtAZcd+fBy9a0SAfjG4oyYP6qj4MH8LUwIth5JjB0oOAFdXXVev9O0H9fH6fC3deqBxZXRJSurcXlcOitWlA2JZLBRuhZJjB0oOAHdSXlWrzzcV6uP1eVqx+5CO/4vu5WFodI8ITR7cUeN6R6mdDwOW4dooOXag5ABwVwfKq/TJhnx9tD6vyWKhAT6emtA3WpcOiNGI7h3k60Xhgeuh5NiBkgOgLdhVVKGPMvP18YY85R7+ccBykK+XLu4TpUn9YzSyewceSYfLoOTYgZIDoC0xTVPrckr0yYYCLd5UoAPl1Y3vBfp66eLekZrYP0aje0RQeODUKDl2oOQAaKtstobC82lWgRZnFaqwvKrxvQAfT43tHaVJ/aN1Yc9ICg+cDiXHDpQcAGgoPJm5pfosq0CLswqUX/Zj4fH38dSYXpGa1D9GF/WMZNAynAIlxw6UHABoymYztWF/Q+H5LKuwyaSD7bwbCs8l/aJ1Yc8IBfl5W5gUbRklxw6UHAA4NdM0tWF/mRZnFejTrALtL/mx8Hh7GhretYPG9YnSuD5RimLhULQiSo4dKDkAYB/TNJWVV6bPsgq1ZHOh9hRXNnl/UHyoxvWJ0oS+UeoaEcjSEmhRlBw7UHIA4OzsKjqiJVsKtXTLgSYLh0pSYocAjesbpfF9ojQ4vr08WDwUDkbJsQMlBwDOXVF5lZZuPaClWw5oxa5Dqqn/cWmJDoG+GtcnUuP6RGl4V+bigWNQcuxAyQEAx6qoqlX6joNauuWAvt5WpIqqusb3/H08NbpHhC7uHaULe0YoPNDXwqRwZZQcO1ByAKDl1NTZ9H32IS3Z3HCV56dz8RiGNDg+VGN7R2ls70j1jApiHA/sRsmxAyUHAFrH8YHLS7cc0Fdbi7SloLzJ+x1D22lMr0iN6R2plMRwbmvhtCg5dqDkAIA1CsqO6ettRfp6a5G+3VWs6rofx/G08/bUiO4dNLZXpMb0ilQkj6fjZyg5dqDkAID1jtXUa8XuYn31Q+n56W0tSerfMURje0dqbK8o9Y0N5mktUHLsQckBAOdimqY255fr621F+mpbkTbkljZ5PyLIV6N7RGh0jwiN7N5Bof4+1gSFpSg5dqDkAIBzK6qoUtq2g/pq2wEt31msozX1je95GA2TEI7uEanRPSPUv2OIPLnK0yZQcuxAyQEA11FdV6812SVK31Gk9B0HtePAkSbvt/f31qjGqzwRigjiEXV3RcmxAyUHAFxXfukxpe84qPTtB/XdrmJVVNc1eb9/x5CGW1s9IzQ4PlRenh4WJYWjUXLsQMkBAPdQW29TZk6p0ncUKW37QW3Ob/qIepCfl0Z276AR3RrG8sSH+VuUFI5AybEDJQcA3FNRRZWW7yhW+o6DWrbzoEqP1jZ5v1OYv0Z076CR3ToopWs4A5hdDCXHDpQcAHB/9TZTG/eXKn1Hw22tzJxS1dl+/LgzDGlAxxBd0K2DRnTvoKTO7eXrxWSEzoySYwdKDgC0PUeq6/T9nkNavrNY3+0q1s6ipgOY/bw9dF5CuEZ266ALunVQr+gg5uZxMpQcO1ByAACFZVX6blexvv3hdbCiusn7HQJ9NLxrB4344dZWXPt2rLNlMUqOHSg5AICfMk1TOw4c0fKdDbe2vs8+3GRuHqlhna3zE8M0LDFcKYmUHitQcuxAyQEAnE5NnU2ZOSX6dlexVu4+pA37S1Vb3/SjktLT+ig5dqDkAACa42hNndbtK9WqPYe0ag+lxyqUHDtQcgAA58Ke0hMb4qdhieE6PzFM5yWEq0u4P6XnHFFy7EDJ
AQA4kj2lJyLIV+clhOm8LmE6LyFMPaN4equ5KDl2oOQAAFrST0vP6uzDWp9bqpp6W5N9Qtp5a2iX9g3FJyFcfWOD5c0SFKdFybEDJQcA0Jqqauu1cX+ZVmcf0vfZh7V2X8kJT2+18/ZUUufjpSdMg+JD5efN5IQ/RcmxAyUHAGClunqbNueXa3X2YX2ffVgZ+w6fsASFj6eHBsSFKKlLeyV3DlNS5/YKC2jby1BQcuxAyQEAOBObzdTOoiONV3pWZx9W0c8mJ5SkrhEBDYWnS3sld26vhA4BbWowMyXHDpQcAIAzM01T+w4d1Zq9Dbe21uw9rN0HK0/YLzzAR0M6t9fQLu2V1DlM/ToGu/X6W5QcO1ByAACupqSyRmv3lShjX4nW7jusDfvLVFPXdDCzj5eHBsaFKKlzmJI7t9cQN7vFRcmxAyUHAODqquvqtSmvXGv3HdaavSVau69EhytrTtgvoUOABncK1eBO7TWkU6h6RgXJy0Wf4qLk2IGSAwBwN6ZpKru4suFKz94SZew7+S0ufx9PDYgL0ZBO7TWkU3sN7hSq8EBfCxI3HyXHDpQcAEBbUHq0RutzS7Uup1SZOSVan1Oqiuq6E/brHO7/Q+lpuOLTK9o5r/ZQcuxAyQEAtEU2m6ldB49o3b4SrcspUWZOqXYWHTlhv3benuofF6KBcSEaEBeqgXGhig+zfi0uSo4dKDkAADQoO1ar9bkNV3qOX/GpqDrxak+ov7f6dwzRwLhQDYgL0cD4UEUF+7VqVkqOHSg5AACcnM1mak/xEa3LKVXW/jJt3F+qrQUVJyxLIUmRQb4/XOkJ+eHKT6jat+DTXJQcO1ByAACwX3VdvXYUHtGG/aXauL9UG/eXaceBCtlO0h7iw9ppQFyoBnQM0W0jEhy6Hhclxw6UHAAAzs2xmnptzi/Thv1lyvqh+Owp/vFprvAAH2U8drFDx/Gczee3l8N+OwAAaBPa+XgquUuYkruENW4rO1arTXll2ri/TDbTtHygskTJAQAADhDSzlsXdOugC7p1sDpKI+d7EB4AAMABKDkAAMAtUXIAAIBbouQAAAC3ZGnJmTVrloYOHaqgoCBFRkbqyiuv1Pbt28/4fenp6UpKSpKfn58SExM1d+7cVkgLAABciaUlJz09Xffcc49WrVqlpUuXqq6uTuPHj1dl5Ykrpx6XnZ2tSZMmaeTIkcrMzNTMmTN1//33KzU1tRWTAwAAZ+dUkwEePHhQkZGRSk9P16hRo066zyOPPKJFixZp69atjdumTp2qDRs2aOXKlWf8HUwGCACA6zmbz2+nGpNTVlYmSQoLCzvlPitXrtT48eObbJswYYIyMjJUW1t7wv7V1dUqLy9v8gIAAO7PaUqOaZqaNm2aRowYoX79+p1yv8LCQkVFRTXZFhUVpbq6OhUXF5+w/6xZsxQSEtL4io+Pd3h2AADgfJym5Nx7773auHGj3nvvvTPu+/Opoo/fcTvZFNIzZsxQWVlZ4ys3N9cxgQEAgFNzimUd7rvvPi1atEjLli1TXFzcafeNjo5WYWFhk21FRUXy8vJSeHj4Cfv7+vrK19fXoXkBAIDzs/RKjmmauvfee7VgwQJ9/fXXSkhIOOP3pKSkaOnSpU22LVmyRMnJyfL29m6pqAAAwMVYWnLuuecevf3223r33XcVFBSkwsJCFRYW6tixY437zJgxQzfffHPj11OnTtW+ffs0bdo0bd26Va+++qrmzZunhx9+2Iq/AgAAcFKW3q6aM2eOJOnCCy9ssv21117TLbfcIkkqKChQTk5O43sJCQn67LPP9NBDD+mFF15QbGysZs+erauvvtqu33l8/A5PWQEA4DqOf243Z+Ybp5onpzXs37+fJ6wAAHBRubm5Zxy/e1ybKzk2m035+fkKCgo66dNY56K8vFzx8fHKzc1losFm4Lg1H8fs7HDczg7H7exw3JrvdMfMNE1VVFQoNjZWHh72jbZxiqerWpOHh4fdDfBsBQcHc0Kf
BY5b83HMzg7H7exw3M4Ox635TnXMQkJCmvVznGaeHAAAAEei5AAAALdEyXEgX19fPf7440w+2Ewct+bjmJ0djtvZ4bidHY5b8zn6mLW5gccAAKBt4EoOAABwS5QcAADglig5AADALVFyAACAW6LkOMiLL76ohIQE+fn5KSkpScuXL7c6klN74oknZBhGk1d0dLTVsZzOsmXLdPnllys2NlaGYeijjz5q8r5pmnriiScUGxurdu3a6cILL9TmzZutCetEznTcbrnllhPOv2HDhlkT1knMmjVLQ4cOVVBQkCIjI3XllVdq+/btTfbhfDuRPceN8+1Ec+bM0YABAxon/UtJSdHixYsb33fUuUbJcYD3339fDz74oB599FFlZmZq5MiRmjhxYpOFRXGivn37qqCgoPGVlZVldSSnU1lZqYEDB+r5558/6fv//Oc/9dRTT+n555/XmjVrFB0drXHjxqmioqKVkzqXMx03SbrkkkuanH+fffZZKyZ0Punp6brnnnu0atUqLV26VHV1dRo/frwqKysb9+F8O5E9x03ifPu5uLg4Pfnkk8rIyFBGRobGjBmjyZMnNxYZh51rJs7ZeeedZ06dOrXJtl69epl/+MMfLErk/B5//HFz4MCBVsdwKZLMhQsXNn5ts9nM6Oho88knn2zcVlVVZYaEhJhz5861IKFz+vlxM03TnDJlijl58mRL8riKoqIiU5KZnp5umibnm71+ftxMk/PNXu3btzdfeeUVh55rXMk5RzU1NVq7dq3Gjx/fZPv48eO1YsUKi1K5hp07dyo2NlYJCQm67rrrtGfPHqsjuZTs7GwVFhY2Ofd8fX01evRozj07pKWlKTIyUj169NAdd9yhoqIiqyM5lbKyMklSWFiYJM43e/38uB3H+XZq9fX1mj9/viorK5WSkuLQc42Sc46Ki4tVX1+vqKioJtujoqJUWFhoUSrnd/755+vNN9/UF198oZdfflmFhYUaPny4Dh06ZHU0l3H8/OLca76JEyfqnXfe0ddff61///vfWrNmjcaMGaPq6mqrozkF0zQ1bdo0jRgxQv369ZPE+WaPkx03ifPtVLKyshQYGChfX19NnTpVCxcuVJ8+fRx6rrW5VchbimEYTb42TfOEbfjRxIkTG//cv39/paSkqGvXrnrjjTc0bdo0C5O5Hs695rv22msb/9yvXz8lJyerc+fO+vTTT3XVVVdZmMw53Hvvvdq4caO+/fbbE97jfDu1Ux03zreT69mzp9avX6/S0lKlpqZqypQpSk9Pb3zfEecaV3LOUYcOHeTp6XlCuywqKjqhheLUAgIC1L9/f+3cudPqKC7j+NNonHvnLiYmRp07d+b8k3Tfffdp0aJF+uabbxQXF9e4nfPt9E513E6G862Bj4+PunXrpuTkZM2aNUsDBw7Us88+69BzjZJzjnx8fJSUlKSlS5c22b506VINHz7colSup7q6Wlu3blVMTIzVUVxGQkKCoqOjm5x7NTU1Sk9P59xrpkOHDik3N7dNn3+maeree+/VggUL9PXXXyshIaHJ+5xvJ3em43YynG8nZ5qmqqurHXuuOWhQdJs2f/5809vb25w3b565ZcsW88EHHzQDAgLMvXv3Wh3NaU2fPt1MS0sz9+zZY65atcq87LLLzKCgII7Zz1RUVJiZmZlmZmamKcl86qmnzMzMTHPfvn2maZrmk08+aYaEhJgLFiwws7KyzOuvv96MiYkxy8vLLU5urdMdt4qKCnP69OnmihUrzOzsbPObb74xU1JSzI4dO7bp43bXXXeZISEhZlpamllQUND4Onr0aOM+nG8nOtNx43w7uRkzZpjLli0zs7OzzY0bN5ozZ840PTw8zCVLlpim6bhzjZLjIC+88ILZuXNn08fHxxwyZEiTxwdxomuvvdaMiYkxvb29zdjYWPOqq64yN2/ebHUsp/PNN9+Ykk54TZkyxTTNhsd6H3/8cTM6Otr09fU1R40aZWZlZVkb2gmc7rgd
PXrUHD9+vBkREWF6e3ubnTp1MqdMmWLm5ORYHdtSJztekszXXnutcR/OtxOd6bhxvp3cbbfd1viZGRERYY4dO7ax4Jim4841wzRN8yyvLAEAADgtxuQAAAC3RMkBAABuiZIDAADcEiUHAAC4JUoOAABwS5QcAADglig5AADALVFyAACAW6LkAAAAt0TJAQAAbomSAwAA3BIlBwAAuKX/D6hnFBqf6pyHAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the recorded training-loss values in the order they were appended\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(loss_history)\n",
    "ax.set_ylabel('train loss')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2eced05a",
   "metadata": {},
   "source": [
    "## 4. 效果测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "553d19c7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "你好\n"
     ]
    }
   ],
   "source": [
    "model = GPT().to(device)\n",
    "# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU also loads on a CPU-only machine\n",
    "model.load_state_dict(torch.load('model/gpt_chat.pt', map_location=device))  # load the trained weights\n",
    "model.eval()  # switch layers such as dropout to evaluation behaviour before generating answers\n",
    "ask = \"你好啊\"\n",
    "print(model.answer(ask))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e9d4a651",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "金香菇和金香\n"
     ]
    }
   ],
   "source": [
    "# Prompt meaning: What is your name?\n",
    "ask = \"你叫什么名字\"\n",
    "reply = model.answer(ask)\n",
    "print(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8bd17408",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "今天天气好了吗\n"
     ]
    }
   ],
   "source": [
    "# Prompt meaning: The weather is nice today\n",
    "ask = \"今天天气不错\"\n",
    "reply = model.answer(ask)\n",
    "print(reply)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
